package org.example.utils.common;

import java.util.*;

/**
 * User-based collaborative-filtering recommender.
 *
 * <p>Similarity between two users is the cosine similarity of their ratings, computed only over
 * the items both users have rated. The predicted score for an item is the similarity-weighted
 * average of the ratings given by the target user's {@code neighborhoodSize} most similar users.
 */
public class RecommenderSystem {
    /** Ratings keyed by user id, then by item id. */
    private final Map<Integer, Map<Integer, Double>> userItemRatingTable;
    /** Maximum number of most-similar users whose ratings contribute to a prediction. */
    private final int neighborhoodSize;

    /**
     * Creates a recommender over the given rating table.
     *
     * <p>The table is used by reference, not copied, so later changes to it are visible to
     * subsequent {@link #recommendItems(int)} calls.
     *
     * @param userItemRatingTable ratings as {@code userId -> (itemId -> rating)}; must not be null
     * @param neighborhoodSize maximum number of neighbors to use; values {@code <= 0} yield no
     *     recommendations
     * @throws NullPointerException if {@code userItemRatingTable} is null
     */
    public RecommenderSystem(Map<Integer, Map<Integer, Double>> userItemRatingTable, int neighborhoodSize) {
        this.userItemRatingTable = Objects.requireNonNull(userItemRatingTable, "userItemRatingTable");
        this.neighborhoodSize = neighborhoodSize;
    }

    /**
     * Predicts scores for the items rated by the users most similar to {@code userId}.
     *
     * <p>Other users are ranked by cosine similarity to the target user (ties broken by ascending
     * user id) and the top {@code neighborhoodSize} are kept. Each item rated by any of them gets
     * the score {@code sum(similarity * rating) / sum(similarity)} over those neighbors, with
     * similarities used as signed weights. Items the target user has already rated are not
     * excluded.
     *
     * @param userId the user to recommend for
     * @return a new mutable map from item id to predicted score; empty when {@code userId} is not
     *     in the rating table, there are no other users, or {@code neighborhoodSize <= 0}. Items
     *     whose summed neighbor similarity is zero are omitted because their score is undefined.
     */
    public Map<Integer, Double> recommendItems(int userId) {
        Map<Integer, Double> recommendedItemScores = new HashMap<>();
        Map<Integer, Double> targetRatings = userItemRatingTable.get(userId);
        if (targetRatings == null || neighborhoodSize <= 0) {
            return recommendedItemScores;
        }

        List<Map.Entry<Integer, Double>> neighbors = rankNeighbors(userId, targetRatings);
        int limit = Math.min(neighborhoodSize, neighbors.size());

        Map<Integer, Double> ratingTotalMap = new HashMap<>();
        Map<Integer, Double> weightTotalMap = new HashMap<>();
        for (Map.Entry<Integer, Double> neighbor : neighbors.subList(0, limit)) {
            double similarity = neighbor.getValue();
            Map<Integer, Double> neighborRatings = userItemRatingTable.get(neighbor.getKey());
            for (Map.Entry<Integer, Double> itemEntry : neighborRatings.entrySet()) {
                int itemId = itemEntry.getKey();
                ratingTotalMap.merge(itemId, similarity * itemEntry.getValue(), Double::sum);
                weightTotalMap.merge(itemId, similarity, Double::sum);
            }
        }

        for (Map.Entry<Integer, Double> ratingTotalEntry : ratingTotalMap.entrySet()) {
            int itemId = ratingTotalEntry.getKey();
            double weightTotal = weightTotalMap.get(itemId);
            // A zero total weight (e.g. only zero-similarity neighbors rated the item) would yield
            // NaN or Infinity, which is not a meaningful prediction, so such items are skipped.
            if (weightTotal != 0.0) {
                recommendedItemScores.put(itemId, ratingTotalEntry.getValue() / weightTotal);
            }
        }
        return recommendedItemScores;
    }

    /**
     * Scores every other user against the target and orders them by descending similarity.
     *
     * <p>A list is used instead of a similarity-keyed map so that users with identical similarity
     * are all retained; ties are broken by ascending user id for deterministic results. Users whose
     * rating map is null are skipped.
     *
     * @param userId the target user, excluded from the result
     * @param targetRatings the target user's ratings
     * @return mutable list of {@code (neighborId, similarity)} pairs, most similar first
     */
    private List<Map.Entry<Integer, Double>> rankNeighbors(int userId, Map<Integer, Double> targetRatings) {
        List<Map.Entry<Integer, Double>> neighbors = new ArrayList<>();
        for (Map.Entry<Integer, Map<Integer, Double>> userEntry : userItemRatingTable.entrySet()) {
            int neighborId = userEntry.getKey();
            Map<Integer, Double> neighborRatings = userEntry.getValue();
            if (neighborId != userId && neighborRatings != null) {
                double similarity = calculateSimilarity(targetRatings, neighborRatings);
                neighbors.add(new AbstractMap.SimpleImmutableEntry<>(neighborId, similarity));
            }
        }
        neighbors.sort((a, b) -> {
            int bySimilarity = Double.compare(b.getValue(), a.getValue());
            return bySimilarity != 0 ? bySimilarity : Integer.compare(a.getKey(), b.getKey());
        });
        return neighbors;
    }

    /**
     * Computes the cosine similarity of two users' ratings, restricted to the items both have
     * rated.
     *
     * @param user1 first user's ratings ({@code itemId -> rating})
     * @param user2 second user's ratings ({@code itemId -> rating})
     * @return the similarity, or {@code 0.0} when the users share no rated items or either user's
     *     shared ratings are all zero
     */
    private static double calculateSimilarity(Map<Integer, Double> user1, Map<Integer, Double> user2) {
        double dotProduct = 0.0;
        double squaredNorm1 = 0.0;
        double squaredNorm2 = 0.0;

        for (Map.Entry<Integer, Double> entry : user1.entrySet()) {
            Double rating2 = user2.get(entry.getKey());
            if (rating2 == null) {
                continue; // item not co-rated, so it does not contribute
            }
            double rating1 = entry.getValue();
            dotProduct += rating1 * rating2;
            squaredNorm1 += rating1 * rating1;
            squaredNorm2 += rating2 * rating2;
        }

        double denominator = Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2);
        return denominator == 0.0 ? 0.0 : dotProduct / denominator;
    }
}